{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Writing a custom acquisition function and interfacing with Ax\n",
    "\n",
    "As seen in the [custom BoTorch model in Ax](./custom_botorch_model_in_ax) tutorial, Ax's `BotorchModel` is flexible: different components of the Bayesian optimization loop can be specified through a functional API. This tutorial walks through writing a custom acquisition function and then plugging it into Ax.\n",
    "\n",
    "\n",
    "### Upper Confidence Bound (UCB)\n",
    "\n",
    "The Upper Confidence Bound (UCB) acquisition function balances exploration and exploitation. For a normal posterior with mean $\\mu$ and variance $\\sigma^2$, it assigns the score $\\mu + \\sqrt{\\beta} \\cdot \\sigma$, where $\\beta$ controls how strongly uncertainty is rewarded. This \"analytic\" version is implemented in the `UpperConfidenceBound` class. The Monte Carlo version of UCB is implemented in the `qUpperConfidenceBound` class, which also supports q-batches of size greater than one. (The derivation of q-UCB is given in Appendix A of [Wilson et al., 2017](https://arxiv.org/pdf/1712.00424.pdf).)"
   ]
  },
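  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on how the two versions relate, the cell below compares the analytic UCB score with a Monte Carlo estimate in the `q=1` case. It is a minimal sketch in plain PyTorch with an assumed toy posterior (the values of `mu`, `sigma` and `beta` are arbitrary). Since $\\mathbb{E}|Y - \\mu| = \\sigma \\sqrt{2/\\pi}$ for $Y \\sim \\mathcal{N}(\\mu, \\sigma^2)$, averaging $\\mu + \\sqrt{\\beta \\pi / 2} \\, |Y - \\mu|$ over samples should recover $\\mu + \\sqrt{\\beta} \\cdot \\sigma$. This explains the factor `math.sqrt(self.beta * math.pi / 2)` in the Monte Carlo code in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "# Assumed toy posterior at a single point: one normal with mean mu and std sigma (q=1, one output).\n",
    "torch.manual_seed(0)\n",
    "mu, sigma, beta = 0.3, 0.5, 0.1\n",
    "\n",
    "# Analytic UCB: mu + sqrt(beta) * sigma\n",
    "analytic_ucb = mu + math.sqrt(beta) * sigma\n",
    "\n",
    "# MC version: average mu + sqrt(beta * pi / 2) * |Y - mu| over samples Y ~ N(mu, sigma^2).\n",
    "# Since E|Y - mu| = sigma * sqrt(2 / pi), its expectation equals the analytic value.\n",
    "samples = mu + sigma * torch.randn(200_000)\n",
    "mc_ucb = (mu + math.sqrt(beta * math.pi / 2) * (samples - mu).abs()).mean().item()\n",
    "\n",
    "print(f'analytic: {analytic_ucb:.4f}, MC estimate: {mc_ucb:.4f}')"
   ]
  },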
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### A scalarized version of q-UCB\n",
    "\n",
    "Suppose now that we are in a multi-output setting where, e.g., we model the effects of a design on multiple metrics. We first show a simple extension of the q-UCB acquisition function that accepts a multi-output model and performs q-UCB on a scalarized version of the multiple outputs, obtained via a vector of weights. Implementing a new acquisition function in BoTorch is easy: one only needs to implement the constructor and a `forward` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "from torch import Tensor\n",
    "from typing import Optional\n",
    "\n",
    "from botorch.acquisition import MCAcquisitionObjective\n",
    "from botorch.acquisition.monte_carlo import MCAcquisitionFunction\n",
    "from botorch.models.model import Model\n",
    "from botorch.sampling.samplers import MCSampler\n",
    "from botorch.utils import t_batch_mode_transform\n",
    "\n",
    "\n",
    "class qScalarizedUpperConfidenceBound(MCAcquisitionFunction):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model: Model,\n",
    "        beta: Tensor,\n",
    "        weights: Tensor,\n",
    "        sampler: Optional[MCSampler] = None,\n",
    "        objective: Optional[MCAcquisitionObjective] = None,\n",
    "    ) -> None:\n",
    "        super().__init__(model=model, sampler=sampler, objective=objective)\n",
    "        self.register_buffer(\"beta\", torch.as_tensor(beta))\n",
    "        self.register_buffer(\"weights\", torch.as_tensor(weights))\n",
    "\n",
    "    @t_batch_mode_transform()\n",
    "    def forward(self, X: Tensor) -> Tensor:\n",
    "        \"\"\"Evaluate scalarized qUCB on the candidate set `X`.\n",
    "\n",
    "        Args:\n",
    "            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim\n",
    "                design points each.\n",
    "\n",
    "        Returns:\n",
    "            Tensor: A `(b)`-dim Tensor of Upper Confidence Bound values at the\n",
    "                given design points `X`.\n",
    "        \"\"\"\n",
    "        posterior = self.model.posterior(X)\n",
    "        samples = self.sampler(posterior)  # n x b x q x o\n",
    "        scalarized_samples = samples.matmul(self.weights)  # n x b x q\n",
    "        mean = posterior.mean  # b x q x o\n",
    "        scalarized_mean = mean.matmul(self.weights)  # b x q\n",
    "        ucb_samples = (\n",
    "            scalarized_mean\n",
    "            + math.sqrt(self.beta * math.pi / 2)\n",
    "            * (scalarized_samples - scalarized_mean).abs()\n",
    "        )\n",
    "        return ucb_samples.max(dim=-1)[0].mean(dim=0)"
   ]
  },
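  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shape bookkeeping in `forward` concrete, the cell below replays its tensor operations on random stand-in tensors instead of a real posterior and sampler. The sizes `n`, `b`, `q`, `o` and the random `mean`/`samples` are assumptions for illustration only. Multiplying by `weights` collapses the output dimension, `scalarized_mean` broadcasts over the leading sample dimension, and the final reduction takes the max over the q-batch and then averages over the MC samples, leaving one value per t-batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "n, b, q, o = 128, 4, 3, 2  # MC samples, t-batches, q-batch size, outputs (assumed sizes)\n",
    "beta = 0.1\n",
    "weights = torch.tensor([0.1, 0.5])\n",
    "\n",
    "# Random stand-ins for the posterior mean and MC samples (shape checks only).\n",
    "mean = torch.randn(b, q, o)  # b x q x o\n",
    "samples = mean + 0.2 * torch.randn(n, b, q, o)  # n x b x q x o\n",
    "\n",
    "scalarized_samples = samples.matmul(weights)  # n x b x q\n",
    "scalarized_mean = mean.matmul(weights)  # b x q\n",
    "# scalarized_mean broadcasts over the leading n dimension\n",
    "ucb_samples = scalarized_mean + math.sqrt(beta * math.pi / 2) * (\n",
    "    scalarized_samples - scalarized_mean\n",
    ").abs()  # n x b x q\n",
    "values = ucb_samples.max(dim=-1).values.mean(dim=0)  # max over q, then mean over n -> b\n",
    "\n",
    "print(ucb_samples.shape, values.shape)"
   ]
  },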
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `qScalarizedUpperConfidenceBound` is very similar to `qUpperConfidenceBound` and only requires a few lines of new code to accommodate scalarization of multiple outputs. The `@t_batch_mode_transform` decorator ensures that the input `X` has an explicit t-batch dimension (code comments indicate tensor shapes for clarity)."
   ]
  },
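  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below illustrates the kind of shape handling that `t_batch_mode_transform` is described as performing: a `q x d` input gains a leading t-batch dimension of size 1, while batched inputs pass through unchanged, and an optional `expected_q` is checked. `add_t_batch_dim` is a hypothetical helper written for this illustration, not BoTorch's actual decorator, and raising `ValueError` is our own choice. This added dimension is why evaluating `qSUCB` on a single `3 x 2` input returns a tensor with one element."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "# Hypothetical helper mimicking the described shape handling; NOT BoTorch's implementation.\n",
    "def add_t_batch_dim(X: torch.Tensor, expected_q: Optional[int] = None) -> torch.Tensor:\n",
    "    if X.dim() < 2:\n",
    "        raise ValueError('X must have at least two dimensions (q x d).')\n",
    "    if X.dim() == 2:\n",
    "        X = X.unsqueeze(0)  # q x d -> 1 x q x d: add an explicit t-batch dimension\n",
    "    if expected_q is not None and X.shape[-2] != expected_q:\n",
    "        raise ValueError(f'Expected q={expected_q}, got q={X.shape[-2]}.')\n",
    "    return X\n",
    "\n",
    "\n",
    "print(add_t_batch_dim(torch.rand(3, 2)).shape)  # single q-batch with q=3\n",
    "print(add_t_batch_dim(torch.rand(2, 3, 2)).shape)  # two q-batches, left unchanged\n",
    "print(add_t_batch_dim(torch.rand(3, 1, 2), expected_q=1).shape)  # three points with q=1"
   ]
  },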
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Ad-hoc testing q-Scalarized-UCB\n",
    "\n",
    "Before hooking the newly defined acquisition function into a Bayesian optimization loop, we should test it. For this, we simply check that it evaluates properly on a compatible multi-output model: a basic multi-output `SingleTaskGP` trained on synthetic data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from botorch.fit import fit_gpytorch_model\n",
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "\n",
    "# generate synthetic data\n",
    "X = torch.rand(20, 2)\n",
    "Y = torch.stack([torch.sin(X[:, 0]), torch.cos(X[:, 1])], -1)\n",
    "\n",
    "# construct and fit the multi-output model\n",
    "gp = SingleTaskGP(X, Y)\n",
    "mll = ExactMarginalLogLikelihood(gp.likelihood, gp)\n",
    "fit_gpytorch_model(mll);\n",
    "\n",
    "# construct the acquisition function\n",
    "qSUCB = qScalarizedUpperConfidenceBound(gp, beta=0.1, weights=torch.tensor([0.1, 0.5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4938], grad_fn=<MeanBackward2>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate on single q-batch with q=3\n",
    "qSUCB(torch.rand(3, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.5833, 0.5478], grad_fn=<MeanBackward2>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch-evaluate on two q-batches with q=3\n",
    "qSUCB(torch.rand(2, 3, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### A scalarized version of analytic UCB (`q=1` only)\n",
    "\n",
    "We can also write an *analytic* version of UCB for a multi-output model, assuming a multivariate normal posterior and `q=1`. The new class `ScalarizedUpperConfidenceBound` subclasses `AnalyticAcquisitionFunction` instead of `MCAcquisitionFunction`. Unlike the MC version, which applies the weights to the MC samples, we directly scalarize the mean vector $\\mu$ and covariance matrix $\\Sigma$ and apply standard UCB to the resulting univariate normal distribution, which has mean $w^T \\mu$ and variance $w^T \\Sigma w$; the score is therefore $w^T \\mu + \\sqrt{\\beta \\, w^T \\Sigma w}$. In addition to the `@t_batch_mode_transform` decorator, here we also pass `expected_q=1` to ensure that the input `X` has `q=1`."
   ]
  },
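  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scalarization step can be checked numerically in isolation. The following sketch uses an assumed random mean (`mu`) and positive-definite covariance (`Sigma`) for `b` points instead of a GP posterior, computes $w^T \\mu + \\sqrt{\\beta \\, w^T \\Sigma w}$ with broadcasting `matmul`, and compares the result with an explicit per-point loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "b, o = 3, 2  # number of t-batches and outputs (assumed sizes)\n",
    "beta = torch.tensor(0.1)\n",
    "w = torch.tensor([0.1, 0.5])\n",
    "\n",
    "# Assumed posterior at b points: mean vectors (b x o) and covariance matrices (b x o x o).\n",
    "mu = torch.randn(b, o)\n",
    "A = torch.randn(b, o, o)\n",
    "Sigma = A @ A.transpose(-1, -2) + 1e-3 * torch.eye(o)  # symmetric positive definite\n",
    "\n",
    "# Scalarize: the univariate normal has mean w^T mu and variance w^T Sigma w.\n",
    "scalarized_mean = mu.matmul(w)  # b\n",
    "scalarized_variance = Sigma.matmul(w).matmul(w)  # b x o -> b\n",
    "ucb = scalarized_mean + (beta * scalarized_variance).sqrt()  # b\n",
    "\n",
    "# Cross-check against an explicit per-point computation.\n",
    "expected = torch.stack([w @ mu[i] + torch.sqrt(beta * (w @ Sigma[i] @ w)) for i in range(b)])\n",
    "print(torch.allclose(ucb, expected, atol=1e-6))"
   ]
  },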
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.acquisition import AnalyticAcquisitionFunction\n",
    "\n",
    "\n",
    "class ScalarizedUpperConfidenceBound(AnalyticAcquisitionFunction):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model: Model,\n",
    "        beta: Tensor,\n",
    "        weights: Tensor,\n",
    "        maximize: bool = True,\n",
    "    ) -> None:\n",
    "        super().__init__(model=model)\n",
    "        self.maximize = maximize\n",
    "        self.register_buffer(\"beta\", torch.as_tensor(beta))\n",
    "        self.register_buffer(\"weights\", torch.as_tensor(weights))\n",
    "\n",
    "    @t_batch_mode_transform(expected_q=1)\n",
    "    def forward(self, X: Tensor) -> Tensor:\n",
    "        \"\"\"Evaluate the scalarized Upper Confidence Bound on the candidate set `X`.\n",
    "\n",
    "        Args:\n",
    "            X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design\n",
    "                points each.\n",
    "\n",
    "        Returns:\n",
    "            A `(b)`-dim Tensor of Upper Confidence Bound values at the given\n",
    "                design points `X`.\n",
    "        \"\"\"\n",
    "        self.beta = self.beta.to(X)\n",
    "        weights = self.weights.to(X)  # match the dtype/device of X\n",
    "        posterior = self.model.posterior(X)\n",
    "        means = posterior.mean.squeeze(dim=-2)  # b x o\n",
    "        scalarized_mean = means.matmul(weights)  # b\n",
    "        covs = posterior.mvn.covariance_matrix  # b x o x o\n",
    "        # w^T Sigma w; matmul broadcasts over any number of batch dimensions\n",
    "        scalarized_variance = covs.matmul(weights).matmul(weights)  # b\n",
    "        delta = (self.beta.expand_as(scalarized_mean) * scalarized_variance).sqrt()\n",
    "        if self.maximize:\n",
    "            return scalarized_mean + delta\n",
    "        else:\n",
    "            return scalarized_mean - delta"
   ]
  },
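  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because both acquisition functions target the same quantity when `q=1`, the MC estimate of scalarized q-UCB should approach the analytic scalarized UCB as the number of samples grows. The following sketch checks this in plain PyTorch by sampling from an assumed bivariate normal posterior at a single point (the values of `mu`, `Sigma`, `w` and `beta` are illustrative), rather than by calling the two classes defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "torch.manual_seed(0)\n",
    "beta = 0.1\n",
    "w = torch.tensor([0.1, 0.5])\n",
    "mu = torch.tensor([0.4, -0.2])  # assumed posterior mean of the two outputs at one point\n",
    "Sigma = torch.tensor([[0.30, 0.05], [0.05, 0.20]])  # assumed posterior covariance\n",
    "\n",
    "# Analytic scalarized UCB (the quantity ScalarizedUpperConfidenceBound computes for q=1).\n",
    "analytic = (w @ mu + math.sqrt(beta) * torch.sqrt(w @ Sigma @ w)).item()\n",
    "\n",
    "# MC scalarized q-UCB with q=1 (the quantity qScalarizedUpperConfidenceBound estimates).\n",
    "samples = MultivariateNormal(mu, covariance_matrix=Sigma).sample(torch.Size([200_000]))  # n x o\n",
    "scalarized = samples.matmul(w)  # n\n",
    "mc = (w @ mu + math.sqrt(beta * math.pi / 2) * (scalarized - w @ mu).abs()).mean().item()\n",
    "\n",
    "print(f'analytic: {analytic:.4f}, MC: {mc:.4f}')"
   ]
  },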
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Ad-hoc testing Scalarized-UCB\n",
    "\n",
    "Notice that we pass in an explicit q-batch dimension for consistency, even though `q=1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct the acquisition function\n",
    "SUCB = ScalarizedUpperConfidenceBound(gp, beta=0.1, weights=torch.tensor([0.1, 0.5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.3583], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# evaluate on single point\n",
    "SUCB(torch.rand(1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.3743, 0.5131, 0.3595], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batch-evaluate on 3 points\n",
    "SUCB(torch.rand(3, 1, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use our newly minted acquisition function within Ax, we need to write a custom factory function and pass it to the constructor of Ax's `BotorchModel` as the `acqf_constructor` argument. This factory has the call signature:\n",
    "\n",
    "```python\n",
    "def acqf_constructor(\n",
    "    model: Model,\n",
    "    objective_weights: Tensor,\n",
    "    outcome_constraints: Optional[Tuple[Tensor, Tensor]],\n",
    "    X_observed: Optional[Tensor] = None,\n",
    "    X_pending: Optional[Tensor] = None,\n",
    "    **kwargs: Any,\n",
    ") -> AcquisitionFunction:\n",
    "```\n",
    "\n",
    "The argument `objective_weights` allows for scalarization of multiple objectives, `outcome_constraints` defines constraints on the outputs of multi-output models, `X_observed` contains previously observed points (useful for acquisition functions such as Noisy Expected Improvement), and `X_pending` contains points that are awaiting evaluation. By default, Ax uses the Noisy Expected Improvement (`qNoisyExpectedImprovement`) acquisition function, so the default value of `acqf_constructor` is `get_NEI` (see the Ax documentation for additional details and context).\n",
    "\n",
    "Note that there is ample flexibility in how the arguments of `acqf_constructor` are used. In `get_NEI`, they are used in some preprocessing steps *before* constructing the acquisition function. They could also be passed directly to the BoTorch acquisition function, or not used at all: all the factory needs to do is return an `AcquisitionFunction`. We now give a bare-bones example of a custom factory function that returns our analytic scalarized-UCB acquisition function.\n",
    "\n",
    "```python\n",
    "def get_scalarized_UCB(\n",
    "    model: Model,\n",
    "    objective_weights: Tensor,\n",
    "    **kwargs: Any,\n",
    ") -> AcquisitionFunction:\n",
    "    return ScalarizedUpperConfidenceBound(model=model, beta=0.2, weights=objective_weights)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By following the example shown in the [custom BoTorch model in Ax](./custom_botorch_model_in_ax) tutorial, a `BotorchModel` can be instantiated with `get_scalarized_UCB` and then run in Ax."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
